{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "c157c036-9cab-4253-aa0f-42717ba30cb7",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-12-24T13:53:45.252189Z",
     "iopub.status.busy": "2024-12-24T13:53:45.251528Z",
     "iopub.status.idle": "2024-12-24T13:53:45.261167Z",
     "shell.execute_reply": "2024-12-24T13:53:45.259014Z",
     "shell.execute_reply.started": "2024-12-24T13:53:45.252132Z"
    }
   },
   "outputs": [],
   "source": [
    "from IPython.display import Image"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f43c9c70-4df8-4d73-9a95-d5bef91ddb53",
   "metadata": {},
   "source": [
    "\n",
    "- Scaling and evaluating sparse autoencoders\n",
    "    - https://openai.com/index/extracting-concepts-from-gpt-4/\n",
    "    - https://cdn.openai.com/papers/sparse-autoencoders.pdf\n",
    "\n",
    "For the sparsity constraint, we use a k-sparse constraint: **only the $k$ largest activations in $h$ are\n",
    "retained**, while **the rest are set to zero** (Makhzani et al., 2013; Gao et al., 2024). This approach avoids\n",
    "issues such as shrinkage, where L1 regularisation can cause feature activations to be systematically\n",
    "lower than their true values, potentially leading to suboptimal representations shrinkage, (Wright\n",
    "et al., 2024; Rajamanoharan et al., 2024).\n",
    "\n",
    "```\n",
    "def k_sparse(self, x):\n",
    "    # 实现k-sparse约束\n",
    "    topk, indices = torch.topk(x, self.k, dim=1)\n",
    "    mask = torch.zeros_like(x).scatter_(1, indices, 1)\n",
    "    return x * mask\n",
    "\n",
    "def k_sparse(self, x):\n",
    "    topk_values, _ = torch.topk(x, self.k, dim=1)\n",
    "    k_smallest = topk_values[:, -1:]\n",
    "    mask = (h >= k_smallest).float()\n",
    "    return x * mask\n",
    "```\n",
    "\n",
    "- topk 与 relu/sigmiod 这些一样也可以是为一种激活函数；"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "ca5d0267-8de2-4389-8a30-60d5caea99f6",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-12-24T13:54:44.405167Z",
     "iopub.status.busy": "2024-12-24T13:54:44.404493Z",
     "iopub.status.idle": "2024-12-24T13:54:44.417928Z",
     "shell.execute_reply": "2024-12-24T13:54:44.415488Z",
     "shell.execute_reply.started": "2024-12-24T13:54:44.405113Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<img src=\"https://miro.medium.com/v2/resize:fit:828/format:webp/1*l9HHMhtm7LO_YKbRpWGlUA.png\" width=\"400\"/>"
      ],
      "text/plain": [
       "<IPython.core.display.Image object>"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# src.shape == index.shape\n",
    "Image(url='https://miro.medium.com/v2/resize:fit:828/format:webp/1*l9HHMhtm7LO_YKbRpWGlUA.png', width=400)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "5cb3a3fa-5519-4794-b9fc-129bbb53d2bf",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-12-24T13:05:17.010339Z",
     "iopub.status.busy": "2024-12-24T13:05:17.009959Z",
     "iopub.status.idle": "2024-12-24T13:05:17.032295Z",
     "shell.execute_reply": "2024-12-24T13:05:17.030067Z",
     "shell.execute_reply.started": "2024-12-24T13:05:17.010308Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x7eee165d9610>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "torch.manual_seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "f342eba2-b039-4e64-b896-93e5cda4dc59",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-12-24T13:06:18.472309Z",
     "iopub.status.busy": "2024-12-24T13:06:18.471653Z",
     "iopub.status.idle": "2024-12-24T13:06:18.482736Z",
     "shell.execute_reply": "2024-12-24T13:06:18.480501Z",
     "shell.execute_reply.started": "2024-12-24T13:06:18.472255Z"
    }
   },
   "outputs": [],
   "source": [
    "x = torch.tensor([[0.1, 0.5, 0.3, 0.8, 0.2, 0.4], \n",
    "                  [0.2, 0.9, 0.3, 0.1, 0.7, 0.4]], \n",
    "                 requires_grad=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "c31b58e4-0ee7-4d44-8cf1-ea5849a8a025",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-12-24T13:06:19.468895Z",
     "iopub.status.busy": "2024-12-24T13:06:19.468262Z",
     "iopub.status.idle": "2024-12-24T13:06:19.477942Z",
     "shell.execute_reply": "2024-12-24T13:06:19.475583Z",
     "shell.execute_reply.started": "2024-12-24T13:06:19.468844Z"
    }
   },
   "outputs": [],
   "source": [
    "topk_values, topk_indices = torch.topk(x, k=2, dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "bc46fabe-afad-4cf3-adef-695b595e7884",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-12-24T13:06:21.634447Z",
     "iopub.status.busy": "2024-12-24T13:06:21.633830Z",
     "iopub.status.idle": "2024-12-24T13:06:21.648472Z",
     "shell.execute_reply": "2024-12-24T13:06:21.646291Z",
     "shell.execute_reply.started": "2024-12-24T13:06:21.634397Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.8000, 0.5000],\n",
       "        [0.9000, 0.7000]], grad_fn=<TopkBackward0>)"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "topk_values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "7ab43d47-7f6e-4ee4-a06d-5e2314a563e3",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-12-24T13:05:48.259116Z",
     "iopub.status.busy": "2024-12-24T13:05:48.258442Z",
     "iopub.status.idle": "2024-12-24T13:05:48.271936Z",
     "shell.execute_reply": "2024-12-24T13:05:48.269535Z",
     "shell.execute_reply.started": "2024-12-24T13:05:48.259061Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[3, 1],\n",
       "        [1, 4]])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "topk_indices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "376a5a27-9ae4-4f76-900b-d8c7c15558d9",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-12-24T13:05:05.052348Z",
     "iopub.status.busy": "2024-12-24T13:05:05.051694Z",
     "iopub.status.idle": "2024-12-24T13:05:05.075451Z",
     "shell.execute_reply": "2024-12-24T13:05:05.073380Z",
     "shell.execute_reply.started": "2024-12-24T13:05:05.052295Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0., 1., 0., 1., 0., 0.],\n",
       "        [0., 1., 0., 0., 1., 0.]])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.zeros_like(x).scatter_(1, topk_indices, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "5d4549f3-0d36-4e4b-b401-db19f964c7a9",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-12-23T15:37:34.802569Z",
     "iopub.status.busy": "2024-12-23T15:37:34.801882Z",
     "iopub.status.idle": "2024-12-23T15:37:34.813068Z",
     "shell.execute_reply": "2024-12-23T15:37:34.810814Z",
     "shell.execute_reply.started": "2024-12-23T15:37:34.802520Z"
    }
   },
   "outputs": [],
   "source": [
    "mask = torch.zeros_like(x).scatter_(1, topk_indices, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "e2a06279-8659-4d14-a938-aae9b33401d3",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-12-23T15:37:43.915515Z",
     "iopub.status.busy": "2024-12-23T15:37:43.914835Z",
     "iopub.status.idle": "2024-12-23T15:37:43.930581Z",
     "shell.execute_reply": "2024-12-23T15:37:43.928381Z",
     "shell.execute_reply.started": "2024-12-23T15:37:43.915466Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.1000, 0.5000, 0.3000, 0.8000, 0.2000, 0.4000],\n",
       "        [0.2000, 0.9000, 0.3000, 0.1000, 0.7000, 0.4000]])"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "f11fad78-60c7-49b4-ae3e-9ef9b49b2388",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-12-23T15:37:38.126537Z",
     "iopub.status.busy": "2024-12-23T15:37:38.125859Z",
     "iopub.status.idle": "2024-12-23T15:37:38.140867Z",
     "shell.execute_reply": "2024-12-23T15:37:38.138981Z",
     "shell.execute_reply.started": "2024-12-23T15:37:38.126488Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.0000, 0.5000, 0.0000, 0.8000, 0.0000, 0.0000],\n",
       "        [0.0000, 0.9000, 0.0000, 0.0000, 0.7000, 0.0000]])"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x * mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1413f4b3-9bb9-4091-a4ee-6a70c7ca9fe9",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
